package com.hubspot.singularity.s3.base;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.hubspot.deploy.S3Artifact;
import com.hubspot.mesos.JavaUtils;
import com.hubspot.singularity.runner.base.sentry.SingularityRunnerExceptionNotifier;
import com.hubspot.singularity.s3.base.config.SingularityS3Configuration;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;

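/**
 * Downloads a single S3 artifact by splitting it into chunks, fetching the chunks in
 * parallel via {@link S3ArtifactChunkDownloader}, and appending them in order to the
 * target file.
 */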
public class S3ArtifactDownloader {
  private final Logger log;
  private final SingularityS3Configuration configuration;
  private final SingularityRunnerExceptionNotifier exceptionNotifier;

  public S3ArtifactDownloader(
    SingularityS3Configuration configuration,
    Logger log,
    SingularityRunnerExceptionNotifier exceptionNotifier
  ) {
    this.configuration = configuration;
    this.log = log;
    this.exceptionNotifier = exceptionNotifier;
  }

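  /**
   * Downloads the artifact to the given path, logging the total duration and wrapping
   * any failure in a {@link RuntimeException}.
   */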
  public void download(S3Artifact s3Artifact, Path downloadTo) {
    final long start = System.currentTimeMillis();
    boolean success = false;

    try {
      downloadThrows(s3Artifact, downloadTo);
      success = true;
    } catch (Throwable t) {
      throw new RuntimeException(t);
    } finally {
      log.info(
        "S3 Download {}/{} finished {} after {}",
        s3Artifact.getS3Bucket(),
        s3Artifact.getS3ObjectKey(),
        success ? "successfully" : "with error",
        JavaUtils.duration(start)
      );
    }
  }

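  // Per-bucket credentials take precedence; otherwise fall back to the globally
  // configured access/secret key pair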
  private BasicAWSCredentials getCredentialsForBucket(String bucketName) {
    if (configuration.getS3BucketCredentials().containsKey(bucketName)) {
      return configuration.getS3BucketCredentials().get(bucketName).toAWSCredentials();
    }

    return new BasicAWSCredentials(
      configuration.getS3AccessKey().get(),
      configuration.getS3SecretKey().get()
    );
  }

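  /**
   * Determines the object length, splits the download into chunks, fetches the chunks
   * on a dedicated thread pool, and combines them while enforcing the overall download
   * timeout.
   */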
  private void downloadThrows(final S3Artifact s3Artifact, final Path downloadTo)
    throws Exception {
    log.info("Downloading {}", s3Artifact);

    ClientConfiguration clientConfiguration = new ClientConfiguration()
      .withSocketTimeout(configuration.getS3ChunkDownloadTimeoutMillis());
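    // "S3SignerType" forces the older v2 S3 signing algorithm instead of SigV4, for
    // endpoints that do not support v4 signing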
    if (configuration.isS3UseV2Signing()) {
      clientConfiguration.setSignerOverride("S3SignerType");
    }

    final AmazonS3 s3Client = new AmazonS3Client(
      getCredentialsForBucket(s3Artifact.getS3Bucket()),
      clientConfiguration
    );

    if (configuration.getS3Endpoint().isPresent()) {
      s3Client.setEndpoint(configuration.getS3Endpoint().get());
    }

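    // Prefer the filesize recorded with the artifact; otherwise ask S3 for it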
    long length = 0;

    if (s3Artifact.getFilesize().isPresent()) {
      length = s3Artifact.getFilesize().get();
    } else {
      // Use a HEAD request (getObjectMetadata) rather than getObject, which would open
      // (and leak) a stream to the full object content just to read its length
      ObjectMetadata objectMetadata = s3Client.getObjectMetadata(
        s3Artifact.getS3Bucket(),
        s3Artifact.getS3ObjectKey()
      );

      Preconditions.checkNotNull(
        objectMetadata,
        "Couldn't find object at %s/%s",
        s3Artifact.getS3Bucket(),
        s3Artifact.getS3ObjectKey()
      );

      length = objectMetadata.getContentLength();
    }

    // Number of configured-size chunks needed to cover the object, rounding up
    int numChunks = (int) (length / configuration.getS3ChunkSize());

    if (length % configuration.getS3ChunkSize() > 0) {
      numChunks++;
    }

    // Guard against zero-length objects so the divisions below cannot divide by zero
    numChunks = Math.max(numChunks, 1);

    // Padded by the remainder so that numChunks * chunkSize always covers the full
    // object length
    final long chunkSize = length / numChunks + (length % numChunks);

    log.info(
      "Downloading {}/{} in {} chunks of {} bytes to {}",
      s3Artifact.getS3Bucket(),
      s3Artifact.getS3ObjectKey(),
      numChunks,
      chunkSize,
      downloadTo
    );

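    // One worker thread per chunk; only chunks after the first are appended to the
    // target in handleChunk, so the first chunk is expected to land directly at the
    // target path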
    final ExecutorService chunkExecutorService = Executors.newFixedThreadPool(
      numChunks,
      new ThreadFactoryBuilder()
        .setDaemon(true)
        .setNameFormat("S3ArtifactDownloaderChunkThread-%d")
        .build()
    );
    final List<Future<Path>> futures = Lists.newArrayListWithCapacity(numChunks);

    for (int chunk = 0; chunk < numChunks; chunk++) {
      futures.add(
        chunkExecutorService.submit(
          new S3ArtifactChunkDownloader(
            configuration,
            log,
            s3Client,
            s3Artifact,
            downloadTo,
            chunk,
            chunkSize,
            length,
            exceptionNotifier
          )
        )
      );
    }

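    // Overall time budget for the whole download; each chunk's wait is deducted from
    // it, and once one chunk fails the remaining futures are cancelled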
    long remainingMillis = configuration.getS3DownloadTimeoutMillis();
    boolean failed = false;

    for (int chunk = 0; chunk < numChunks; chunk++) {
      final Future<Path> future = futures.get(chunk);

      if (failed) {
        future.cancel(true);
        continue;
      }

      final long start = System.currentTimeMillis();

      if (!handleChunk(s3Artifact, future, downloadTo, chunk, start, remainingMillis)) {
        failed = true;
      }

      remainingMillis -= (System.currentTimeMillis() - start);
    }

    chunkExecutorService.shutdownNow();

    Preconditions.checkState(
      !failed,
      "Downloading %s/%s failed",
      s3Artifact.getS3Bucket(),
      s3Artifact.getS3ObjectKey()
    );
  }

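  /**
   * Waits for a single chunk to finish within the remaining time budget, appending it
   * to the destination file when it is not the first chunk. Returns true on success; on
   * timeout or error, notifies and returns false so remaining chunks can be cancelled.
   */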
  private boolean handleChunk(
    S3Artifact s3Artifact,
    Future<Path> future,
    Path downloadTo,
    int chunk,
    long start,
    long remainingMillis
  ) {
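    // Clamp so Future.get always receives a positive timeout, even once the overall
    // budget has been exhausted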
    if (remainingMillis <= 0) {
      remainingMillis = 1;
    }

    try {
      Path path = future.get(remainingMillis, TimeUnit.MILLISECONDS);

      if (chunk > 0) {
        combineChunk(downloadTo, path);
      }

      return true;
    } catch (TimeoutException te) {
      log.error(
        "Chunk {} for {} timed out after {} - had {} remaining",
        chunk,
        s3Artifact.getFilename(),
        JavaUtils.duration(start),
        JavaUtils.durationFromMillis(remainingMillis)
      );
      future.cancel(true);
      exceptionNotifier.notify(
        "TimeoutException during download",
        te,
        ImmutableMap.of(
          "filename",
          s3Artifact.getFilename(),
          "chunk",
          Integer.toString(chunk)
        )
      );
    } catch (Throwable t) {
      log.error(
        "Error while handling chunk {} for {}",
        chunk,
        s3Artifact.getFilename(),
        t
      );
      exceptionNotifier.notify(
        String.format("Error handling chunk (%s)", t.getMessage()),
        t,
        ImmutableMap.of(
          "filename",
          s3Artifact.getFilename(),
          "chunk",
          Integer.toString(chunk)
        )
      );
    }

    return false;
  }

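  /**
   * Appends a downloaded chunk file to the destination file; the chunk file is deleted
   * once the transfer completes (DELETE_ON_CLOSE).
   */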
  @SuppressFBWarnings(
    value = "RCN_REDUNDANT_NULLCHECK_WOULD_HAVE_BEEN_A_NPE",
    justification = "https://github.com/spotbugs/spotbugs/issues/259"
  )
  private void combineChunk(Path downloadTo, Path path) throws Exception {
    final long start = System.currentTimeMillis();
    long bytes = 0;

    log.info("Writing {} to {}", path, downloadTo);

    try (
      WritableByteChannel wbs = Files.newByteChannel(
        downloadTo,
        EnumSet.of(StandardOpenOption.APPEND, StandardOpenOption.WRITE)
      )
    ) {
      try (
        FileChannel readChannel = FileChannel.open(
          path,
          EnumSet.of(StandardOpenOption.READ, StandardOpenOption.DELETE_ON_CLOSE)
        )
      ) {
        bytes = readChannel.size();
        readChannel.transferTo(0, bytes, wbs);
      }
    }

    log.info("Finished writing {} bytes in {}", bytes, JavaUtils.duration(start));
  }
}
